{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets,transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#超参数\n",
    "batch_size = 200\n",
    "epochs = 10\n",
    "learning_rate = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data',train=True,\n",
    "                    transform=transforms.Compose([transforms.ToTensor(),\n",
    "                    transforms.Normalize((0.1307,),(0.3081,))])),\n",
    "    batch_size = batch_size,shuffle = True)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data',train=False,\n",
    "                    transform=transforms.Compose([transforms.ToTensor(),\n",
    "                    transforms.Normalize((0.1307,),(0.3081,))])),\n",
    "    batch_size = batch_size,shuffle = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "w1,b1 = torch.randn(200,784,requires_grad=True),torch.zeros(200,requires_grad=True)\n",
    "w2,b2 = torch.randn(200,200,requires_grad=True),torch.zeros(200,requires_grad=True)\n",
    "w3,b3 = torch.randn(10,200,requires_grad=True),torch.zeros(10,requires_grad=True)\n",
    "\n",
    "nn.init.kaiming_normal_(w1)\n",
    "nn.init.kaiming_normal_(w2)\n",
    "nn.init.kaiming_normal_(w3)\n",
    "\n",
    "def forward(x):\n",
    "    X = torch.relu(x@w1.t()+b1)\n",
    "    X = torch.relu(X@w2.t()+b2)\n",
    "    X = torch.relu(X@w3.t()+b3)\n",
    "    return X\n",
    "\n",
    "optimizer = optim.Adam([w1,b1,w2,b2,w3,b3],lr=learning_rate)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train epoch: 0 [0/60000 (0%)] \t Loss: 0.297809\n",
      "Train epoch: 0 [20000/60000 (33%)] \t Loss: 0.250234\n",
      "Train epoch: 0 [40000/60000 (67%)] \t Loss: 0.198960\n",
      "\n",
      "Test set: Average loss: 0.0015, Accuracy: 8853/10000 (88%)\n",
      "\n",
      "Train epoch: 1 [0/60000 (0%)] \t Loss: 0.153826\n",
      "Train epoch: 1 [20000/60000 (33%)] \t Loss: 0.302364\n",
      "Train epoch: 1 [40000/60000 (67%)] \t Loss: 0.266285\n",
      "\n",
      "Test set: Average loss: 0.0015, Accuracy: 8853/10000 (88%)\n",
      "\n",
      "Train epoch: 2 [0/60000 (0%)] \t Loss: 0.104403\n",
      "Train epoch: 2 [20000/60000 (33%)] \t Loss: 0.280470\n",
      "Train epoch: 2 [40000/60000 (67%)] \t Loss: 0.210796\n",
      "\n",
      "Test set: Average loss: 0.0015, Accuracy: 8845/10000 (88%)\n",
      "\n",
      "Train epoch: 3 [0/60000 (0%)] \t Loss: 0.295044\n",
      "Train epoch: 3 [20000/60000 (33%)] \t Loss: 0.203187\n",
      "Train epoch: 3 [40000/60000 (67%)] \t Loss: 0.142824\n",
      "\n",
      "Test set: Average loss: 0.0016, Accuracy: 8858/10000 (88%)\n",
      "\n",
      "Train epoch: 4 [0/60000 (0%)] \t Loss: 0.244551\n",
      "Train epoch: 4 [20000/60000 (33%)] \t Loss: 0.189632\n",
      "Train epoch: 4 [40000/60000 (67%)] \t Loss: 0.235048\n",
      "\n",
      "Test set: Average loss: 0.0016, Accuracy: 8850/10000 (88%)\n",
      "\n",
      "Train epoch: 5 [0/60000 (0%)] \t Loss: 0.303495\n",
      "Train epoch: 5 [20000/60000 (33%)] \t Loss: 0.242201\n",
      "Train epoch: 5 [40000/60000 (67%)] \t Loss: 0.277041\n",
      "\n",
      "Test set: Average loss: 0.0016, Accuracy: 8823/10000 (88%)\n",
      "\n",
      "Train epoch: 6 [0/60000 (0%)] \t Loss: 0.277767\n",
      "Train epoch: 6 [20000/60000 (33%)] \t Loss: 0.214068\n",
      "Train epoch: 6 [40000/60000 (67%)] \t Loss: 0.187285\n",
      "\n",
      "Test set: Average loss: 0.0016, Accuracy: 8860/10000 (88%)\n",
      "\n",
      "Train epoch: 7 [0/60000 (0%)] \t Loss: 0.249266\n",
      "Train epoch: 7 [20000/60000 (33%)] \t Loss: 0.301418\n",
      "Train epoch: 7 [40000/60000 (67%)] \t Loss: 0.308547\n",
      "\n",
      "Test set: Average loss: 0.0016, Accuracy: 8860/10000 (88%)\n",
      "\n",
      "Train epoch: 8 [0/60000 (0%)] \t Loss: 0.252637\n",
      "Train epoch: 8 [20000/60000 (33%)] \t Loss: 0.165470\n",
      "Train epoch: 8 [40000/60000 (67%)] \t Loss: 0.252666\n",
      "\n",
      "Test set: Average loss: 0.0016, Accuracy: 8863/10000 (88%)\n",
      "\n",
      "Train epoch: 9 [0/60000 (0%)] \t Loss: 0.253619\n",
      "Train epoch: 9 [20000/60000 (33%)] \t Loss: 0.221383\n",
      "Train epoch: 9 [40000/60000 (67%)] \t Loss: 0.210254\n",
      "\n",
      "Test set: Average loss: 0.0016, Accuracy: 8848/10000 (88%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(epochs):\n",
    "    for idx,(train,target) in enumerate(train_loader):\n",
    "        train = train.view(-1,28*28)\n",
    "        logits = forward(train)\n",
    "        loss = criterion(logits,target)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if idx % 100 == 0:\n",
    "            print('Train epoch: {} [{}/{} ({:.0f}%)] \\t Loss: {:.6f}'\n",
    "                  .format(epoch,idx*len(train),len(train_loader.dataset),100. * idx / len(train_loader),loss.item()))\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    for data,target in test_loader:\n",
    "        data = data.view(-1,28*28)\n",
    "        logits = forward(data)\n",
    "        test_loss += criterion(logits,target).item()\n",
    "        predict_data = logits.data.max(1)[1]\n",
    "        correct += predict_data.eq(target).sum()\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n",
      "torch.Size([200, 784])\n"
     ]
    }
   ],
   "source": [
    "for data,target in test_loader:\n",
    "    data = data.view(-1,28*28)\n",
    "    print(data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
